{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.contrib import slim\n",
    "import numpy as np\n",
    "from tqdm import tqdm_notebook\n",
    "from matplotlib import pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Samples per training step; also the number of points plotted per label.\n",
    "BATCH_SIZE = 512\n",
    "# Learning rate handed to the Adam optimizer.\n",
    "LR = 2e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def get_data_samples(N):\n",
    "    \"\"\"Return an int32 tensor of N labels drawn uniformly from {0, 1, 2, 3}.\"\"\"\n",
    "    return tf.random_uniform(shape=[N], minval=0, maxval=4, dtype=tf.int32)\n",
    "\n",
    "def encoder_func(x):\n",
    "    \"\"\"Map an input batch to the mean and log-std of a 2-d Gaussian code.\n",
    "\n",
    "    Three 64-unit ELU layers feed two separate 2-unit linear heads.\n",
    "\n",
    "    Returns:\n",
    "        (zmean, zlogstd): outputs of the mean head and the log-std head.\n",
    "    \"\"\"\n",
    "    hidden = x\n",
    "    for _ in range(3):\n",
    "        hidden = slim.fully_connected(hidden, 64, activation_fn=tf.nn.elu)\n",
    "\n",
    "    mean_head = slim.fully_connected(hidden, 2, activation_fn=None)\n",
    "    logstd_head = slim.fully_connected(hidden, 2, activation_fn=None)\n",
    "    return mean_head, logstd_head\n",
    "\n",
    "\n",
    "def decoder_func(z):\n",
    "    \"\"\"Map a latent batch to 4 unnormalised logits, one per one-hot entry.\n",
    "\n",
    "    The trunk is three 64-unit ELU layers followed by a linear output layer.\n",
    "    \"\"\"\n",
    "    hidden = z\n",
    "    for _ in range(3):\n",
    "        hidden = slim.fully_connected(hidden, 64, activation_fn=tf.nn.elu)\n",
    "\n",
    "    return slim.fully_connected(hidden, 4, activation_fn=None)\n",
    "\n",
    "def create_scatter(x_test_labels, eps_test, savepath=None):\n",
    "    \"\"\"Scatter-plot the latent codes inferred for each of the four labels.\n",
    "\n",
    "    Reads the notebook globals `sess`, `z_inferred`, `x_real_labels` and `eps`,\n",
    "    so it can only be called after the graph and session cells have run.\n",
    "\n",
    "    Args:\n",
    "        x_test_labels: sequence of four label batches; entry i is fed in place\n",
    "            of `x_real_labels` for the i-th scatter series.\n",
    "        eps_test: noise array fed in place of `eps`, reused for every label.\n",
    "        savepath: optional output path. When given, the figure is written there\n",
    "            and then closed; otherwise it is left open for inline display.\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(5, 5), facecolor='w')\n",
    "\n",
    "    for i in range(4):\n",
    "        z_out = sess.run(z_inferred, feed_dict={x_real_labels: x_test_labels[i], eps: eps_test})\n",
    "        plt.scatter(z_out[:, 0], z_out[:, 1], edgecolor='none', alpha=0.5)\n",
    "\n",
    "    plt.xlim(-3, 3)\n",
    "    plt.ylim(-3.5, 3.5)\n",
    "    plt.axis('off')\n",
    "\n",
    "    if savepath:\n",
    "        plt.savefig(savepath)\n",
    "        # Saved snapshots are produced in a long loop; close each one so open\n",
    "        # figures do not accumulate in pyplot's registry and exhaust memory.\n",
    "        plt.close(fig)\n",
    "\n",
    "# make_template wraps each builder so that repeated calls reuse one set of\n",
    "# variables, created under the 'encoder' / 'decoder' scopes.\n",
    "encoder = tf.make_template('encoder', encoder_func)\n",
    "decoder = tf.make_template('decoder', decoder_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Reparameterisation noise and a batch of random labels, one-hot encoded for the encoder.\n",
    "eps = tf.random_normal([BATCH_SIZE, 2])\n",
    "x_real_labels = get_data_samples(BATCH_SIZE)\n",
    "x_real = tf.one_hot(x_real_labels, depth=4)\n",
    "\n",
    "# Encode, draw z = mean + eps * std, then decode back to logits.\n",
    "zmean, zlogstd = encoder(x_real)\n",
    "z_inferred = zmean + eps*tf.exp(zlogstd)\n",
    "x_reconstr_logits = decoder(z_inferred)\n",
    "\n",
    "# Reconstruction term: each of the 4 one-hot entries is scored as an independent Bernoulli.\n",
    "per_entry_xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=x_real, logits=x_reconstr_logits)\n",
    "reconstr_err = tf.reduce_sum(per_entry_xent, axis=1)\n",
    "\n",
    "# Single-sample KL estimate: averaged over eps, 0.5*z**2 equals\n",
    "# 0.5*(zmean**2 + exp(2*zlogstd)), which recovers the closed-form KL to N(0, I).\n",
    "kl_per_dim = 0.5*tf.square(z_inferred) - zlogstd - 0.5\n",
    "KL = tf.reduce_sum(kl_per_dim, axis=1)\n",
    "\n",
    "# Minimised objective: batch mean of reconstruction error plus KL.\n",
    "loss = tf.reduce_mean(reconstr_err + KL)\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=LR)\n",
    "train_op = optimizer.minimize(loss)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Build the initialiser op, then open the session the rest of the notebook uses as `sess`.\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Fixed evaluation inputs: one constant-label batch per class and one shared noise draw,\n",
    "# so successive snapshots differ only because the weights changed.\n",
    "x_test_labels = [[i] * BATCH_SIZE for i in range(4)]\n",
    "eps_test = np.random.randn(BATCH_SIZE, 2)\n",
    "\n",
    "outdir = './out_vae'\n",
    "# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "progress = tqdm_notebook(range(100000))\n",
    "for i in progress:\n",
    "    loss_out, _ = sess.run([loss, train_op])\n",
    "    # `loss` is reconstruction error plus KL, i.e. the negative ELBO; flip the sign\n",
    "    # so the number shown under the 'ELBO' label really is the ELBO estimate.\n",
    "    ELBO_out = -loss_out\n",
    "\n",
    "    progress.set_description('ELBO = %.2f' % ELBO_out)\n",
    "    # Save a latent-space snapshot every 100 steps.\n",
    "    if i % 100 == 0:\n",
    "        create_scatter(x_test_labels, eps_test, savepath=os.path.join(outdir, '%08d.png' % i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
